"""Generate plots for Simulation B.

This script reads the summary CSV produced by the driver and generates
figures comparing the measured visibilities with the theoretical
predictions for both Gaussian and uniform jitter laws.  It also
visualises the effect of each ablation by overlaying the degraded
curves on the baseline.
"""

from __future__ import annotations

import argparse
from pathlib import Path
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .metrics import bootstrap_ci, rmse


def charfun_gaussian(sigma: np.ndarray) -> np.ndarray:
    """Return the Gaussian characteristic function ``exp(-sigma**2 / 2)``.

    ``exp(-sigma**2 / 2)`` is the characteristic function (at unit frequency)
    of a zero-mean normal phase with standard deviation ``sigma``; it is
    evaluated elementwise and used as the predicted visibility.
    """
    variance = sigma * sigma
    return np.exp(-variance / 2.0)


def charfun_uniform(a: np.ndarray) -> np.ndarray:
    """Return ``|sin(a) / a|``, the predicted visibility for uniform jitter.

    ``sin(a) / a`` is the characteristic function (at unit frequency) of a
    phase uniformly distributed on ``[-a, a]``.  Its magnitude is returned, so
    the result lies in ``[0, 1]`` for finite ``a`` of either sign.  The
    removable singularity at ``a = 0`` is filled with its limit, 1.

    Parameters
    ----------
    a : array_like
        Jitter parameter(s) in radians.  Scalars, lists and integer arrays are
        accepted; they are converted to a float array first.

    Returns
    -------
    np.ndarray
        Float array with the shape of ``a`` (0-d for scalar input).
    """
    # Cast to float: for integer input, ``np.ones_like`` would otherwise give
    # an integer buffer and every ratio below would be truncated to 0.
    a = np.asarray(a, dtype=float)
    res = np.ones_like(a)
    # Divide only where a != 0 (no divide-by-zero warning); at a == 0 ``res``
    # keeps its preset limit value of 1.
    np.divide(np.abs(np.sin(a)), np.abs(a), out=res, where=a != 0.0)
    return res


def plot_family(df: pd.DataFrame, law: str, out_dir: Path) -> None:
    """Plot measured visibility versus jitter strength for one jitter law.

    Rows of ``df`` whose ``law`` equals ``law`` are grouped by ``param_value``.
    For each value the median of ``V`` is drawn with error bars from
    ``bootstrap_ci``, together with the theoretical prediction and a faint
    scatter of the individual seeds.  The figure is saved to ``out_dir`` as
    ``simB_gaussian.png`` or ``simB_uniform.png``.

    Parameters
    ----------
    df : pd.DataFrame
        Summary table; must provide the columns ``law``, ``param_value``,
        ``seed`` and ``V``.
    law : str
        ``"gaussian"`` or ``"uniform"``.  If no rows match, nothing is plotted.
    out_dir : Path
        Directory the PNG is written to (must already exist).

    Raises
    ------
    ValueError
        If rows exist for ``law`` but it is neither ``"gaussian"`` nor
        ``"uniform"``.
    """
    sub = df[df["law"] == law]
    if sub.empty:
        return

    if law == "gaussian":
        charfun = charfun_gaussian
        xlabel = "σ (radians)"
        title = "Gaussian jitter: visibility vs σ"
        fname = "simB_gaussian.png"
    elif law == "uniform":
        charfun = charfun_uniform
        xlabel = "a (radians)"
        title = "Uniform jitter: visibility vs a"
        fname = "simB_uniform.png"
    else:
        # Refuse unknown laws rather than plotting them against the uniform
        # prediction and overwriting simB_uniform.png.
        raise ValueError(f"Unknown jitter law: {law!r}")

    # Median and bootstrap interval of V across seeds, per parameter value.
    x = np.sort(sub["param_value"].unique())
    medians = np.empty(x.size)
    lows = np.empty(x.size)
    highs = np.empty(x.size)
    for i, p in enumerate(x):
        vals = sub.loc[sub["param_value"] == p, "V"]
        medians[i] = vals.median()
        lows[i], highs[i] = bootstrap_ci(vals)
    # Matplotlib rejects negative error-bar lengths, which would occur if an
    # interval failed to bracket the median (possible for tiny samples).
    yerr = np.clip([medians - lows, highs - medians], 0.0, None)

    # Evaluate the prediction on a dense grid so the curve's shape (e.g. the
    # zeros of the uniform law) is not lost between sparse measured points.
    if x.size > 1:
        x_pred = np.linspace(x.min(), x.max(), 400)
    else:
        x_pred = x.astype(float)

    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.plot(x_pred, charfun(x_pred), color="black", linestyle="--", label="Prediction")
        # NOTE(review): the label assumes bootstrap_ci returns a 68% interval
        # by default — confirm against .metrics.
        ax.errorbar(
            x, medians, yerr=yerr, fmt="o", capsize=5,
            label="Measured (median ±68% CI)",
        )

        # Scatter individual seeds, each shifted horizontally by its rank among
        # the sorted seeds (centred on zero).  Using the rank rather than the
        # raw seed value keeps offsets small for arbitrary seeds such as 42 or
        # 2024; for consecutive seeds it equals offsetting by (seed - mean).
        seeds = np.sort(sub["seed"].unique())
        centre = (seeds.size - 1) / 2
        for rank, seed in enumerate(seeds):
            rows = sub[sub["seed"] == seed]
            offset = (rank - centre) * 0.005
            ax.scatter(rows["param_value"] + offset, rows["V"], alpha=0.3, s=15)

        ax.set_xlabel(xlabel)
        ax.set_ylabel("Visibility V")
        ax.set_title(title)
        ax.set_ylim(0, 1.05)
        ax.grid(True, ls="--", alpha=0.5)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_dir / fname, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def plot_ablation(df_main: pd.DataFrame, df_abl: pd.DataFrame, law: str, out_dir: Path) -> None:
    """Overlay ablation results on the baseline for one jitter law.

    Each curve shows the median ``V`` per ``param_value``, aligned to the
    parameter values present in the baseline; values missing from an ablation
    become NaN and so appear as gaps.  The theoretical prediction is drawn for
    reference.  The figure is saved to ``out_dir`` as
    ``simB_gaussian_ablation.png`` or ``simB_uniform_ablation.png``.

    Parameters
    ----------
    df_main : pd.DataFrame
        Baseline summary with columns ``law``, ``param_value`` and ``V``.
    df_abl : pd.DataFrame
        Ablation results with the same columns plus ``ablation``, whose values
        are used as curve labels.
    law : str
        ``"gaussian"`` or ``"uniform"``.  If the baseline has no rows for it,
        nothing is plotted.
    out_dir : Path
        Directory the PNG is written to (must already exist).

    Raises
    ------
    ValueError
        If baseline rows exist for ``law`` but it is neither ``"gaussian"`` nor
        ``"uniform"``.
    """
    base = df_main[df_main["law"] == law]
    if base.empty:
        return

    if law == "gaussian":
        charfun = charfun_gaussian
        fname = "simB_gaussian_ablation.png"
        xlabel = "σ (radians)"
        title = "Gaussian jitter ablations"
    elif law == "uniform":
        charfun = charfun_uniform
        fname = "simB_uniform_ablation.png"
        xlabel = "a (radians)"
        title = "Uniform jitter ablations"
    else:
        # Refuse unknown laws rather than plotting them against the uniform
        # prediction and overwriting simB_uniform_ablation.png.
        raise ValueError(f"Unknown jitter law: {law!r}")

    ab = df_abl[df_abl["law"] == law]
    x = np.sort(base["param_value"].unique())

    def median_curve(rows: pd.DataFrame) -> np.ndarray:
        # Median V per parameter value, reindexed onto the baseline grid so
        # every curve shares the same x positions (NaN where absent).
        return rows.groupby("param_value")["V"].median().reindex(x).to_numpy()

    # Evaluate the prediction on a dense grid so the curve's shape is not lost
    # between sparse measured points.
    if x.size > 1:
        x_pred = np.linspace(x.min(), x.max(), 400)
    else:
        x_pred = x.astype(float)

    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.plot(x_pred, charfun(x_pred), color="black", linestyle="--", label="Prediction")
        ax.plot(x, median_curve(base), marker="o", label="Baseline")
        for ab_name in ab["ablation"].unique():
            ax.plot(
                x, median_curve(ab[ab["ablation"] == ab_name]),
                marker="o", linestyle="--", label=ab_name,
            )
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Visibility V")
        ax.set_title(title)
        ax.set_ylim(0, 1.05)
        ax.grid(True, ls="--", alpha=0.5)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_dir / fname, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def main(summary_path: str, output_dir: str, ablation_path: str | None = None) -> None:
    """Render the Simulation B figures into ``output_dir``.

    The output directory is created (with parents) if needed.  ``plot_family``
    is called once for each distinct ``law`` in the summary CSV.  When
    ``ablation_path`` is truthy, that CSV is also loaded and ``plot_ablation``
    is called for each of the same laws.
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)
    summary = pd.read_csv(summary_path)
    laws = summary["law"].unique()

    for law in laws:
        plot_family(summary, law, target)

    if ablation_path:
        ablations = pd.read_csv(ablation_path)
        for law in laws:
            plot_ablation(summary, ablations, law, target)

    print(f"Figures saved to {target}")


if __name__ == "__main__":
    # Command-line entry point: parse the CSV locations and hand them to main().
    cli = argparse.ArgumentParser(description="Generate plots for Simulation B")
    cli.add_argument("--summary", required=True, type=str, help="Path to summary CSV")
    cli.add_argument("--output_dir", required=True, type=str, help="Directory to save figures")
    cli.add_argument("--ablation", default=None, type=str, help="Path to ablation CSV (optional)")
    opts = cli.parse_args()
    main(summary_path=opts.summary, output_dir=opts.output_dir, ablation_path=opts.ablation)